import Config.globalconfig as Config
import torch
from LFCC_RESNET40000.model import Model_zoo
import os
import soundfile as sf
import ASVBaselineTool.feature_tools as ft

# Locations of the checkpoint, the label file and the audio directory,
# all taken from the shared configuration module.
model_path = Config.RESNETLCNN40000_MS_CHOICE_PATH
labelPath = Config.TLABEL
datasetPath = Config.TPATH

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Build the network on the chosen device, restore the saved weights
# (remapped onto that device) and switch to evaluation mode.
model = Model_zoo("ResNet18").to(device)
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model.eval()
print(f'Model loaded : {model_path}')

# Read the label file: the first whitespace-separated field of each line is
# the audio file name and the last field is its class label.
with open(labelPath, "r") as label_file:
    Labels = label_file.readlines()

# Split on any whitespace and ignore blank lines, so tabs, repeated spaces
# or a trailing space cannot produce an empty file name or an empty label.
_label_fields = [line.split() for line in Labels if line.strip()]
AFnames = [fields[0] for fields in _label_fields]
ALabels = [fields[-1] for fields in _label_fields]

# Map file name -> two-element target tensor:
#   "genuine" -> [0.0, 1.0] (index 1), "fake" -> [1.0, 0.0] (index 0).
# Any other label string is stored unchanged, as before.
LabelMeta = {}
for fname, label in zip(AFnames, ALabels):
    if label == "genuine":
        LabelMeta[fname] = torch.Tensor([0.0, 1.0])
    elif label == "fake":
        LabelMeta[fname] = torch.Tensor([1.0, 0.0])
    else:
        LabelMeta[fname] = label


# Run the model on every file in the dataset directory and report accuracy
# as the fraction of files whose predicted class index matches the label.
allff = os.listdir(datasetPath)
a = 0  # number of misclassified files
b = 0  # number of correctly classified files
skipped = 0  # files with no tensor label in LabelMeta

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for item in allff:
        # Skip directory entries that have no label, or whose label was not
        # recognised, instead of aborting the whole evaluation.
        y = LabelMeta.get(item)
        if not isinstance(y, torch.Tensor):
            skipped += 1
            continue

        # Index of the largest entry of the two-element target vector.
        init_y = y.max(0, keepdim=True)[1].item()
        if init_y == 1:
            print(init_y)

        src_wav, _ = sf.read(os.path.join(datasetPath, item))
        src_wav = torch.Tensor(src_wav)
        adv_feature = ft.get_LFCC_70(src_wav)

        # Move the features to the device the model lives on; a hard-coded
        # .cuda() would fail on CPU-only machines.
        batch_x = model(adv_feature.to(device))
        init_pred = torch.max(batch_x, dim=1)[1].item()

        if init_pred != init_y:
            a += 1
        else:
            b += 1

if skipped:
    print('Skipped {} file(s) without a usable label'.format(skipped))
if a + b:
    print(b / (a + b))
else:
    # Avoid ZeroDivisionError when the directory had nothing to evaluate.
    print('No labelled files were evaluated')